{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "531467a2-5160-4073-a990-0d81d574b014",
   "metadata": {},
   "source": [
    "## (1) Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d9337043-4e7a-4b20-9d89-6c6257245334",
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import Mamba, ModelArgs\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Name of the pretrained checkpoint to load; pick one of:\n",
    "#     'state-spaces/mamba-2.8b-slimpj'\n",
    "#     'state-spaces/mamba-2.8b'\n",
    "#     'state-spaces/mamba-1.4b'\n",
    "#     'state-spaces/mamba-790m'\n",
    "#     'state-spaces/mamba-370m'\n",
    "#     'state-spaces/mamba-130m'\n",
    "pretrained_model_name = 'state-spaces/mamba-370m'\n",
    "\n",
    "# The model weights come from `pretrained_model_name`; the tokenizer is loaded\n",
    "# from a separate, fixed checkpoint rather than from the model name.\n",
    "model = Mamba.from_pretrained(pretrained_model_name)\n",
    "tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b2efb17-37ad-472b-b029-9567acf17629",
   "metadata": {},
   "source": [
    "## (2) Generate Text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c4b2d62d-0d95-4a3f-bd98-aa37e3f26b39",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "def generate(model,\n",
    "             tokenizer,\n",
    "             prompt: str,\n",
    "             n_tokens_to_gen: int = 50,\n",
    "             sample: bool = True,\n",
    "             top_k: int = 40):\n",
    "    \"\"\"Autoregressively extend `prompt` and return the decoded text.\n",
    "\n",
    "    Each step feeds the full token sequence to `model`, takes the logits at\n",
    "    the last position, optionally keeps only the `top_k` most likely tokens,\n",
    "    and appends either a sampled token or the most likely one.\n",
    "\n",
    "    Args:\n",
    "        model: Module called as `model(input_ids)`; its output is indexed\n",
    "            with `[:, -1]` to obtain the next-token logits.\n",
    "        tokenizer: Callable as `tokenizer(prompt, return_tensors='pt')` whose\n",
    "            result exposes `.input_ids`, and which provides `.decode(ids)`.\n",
    "        prompt: Text to continue.\n",
    "        n_tokens_to_gen: Number of tokens to append to the prompt.\n",
    "        sample: Draw from the (filtered) distribution when True; take the\n",
    "            argmax when False.\n",
    "        top_k: Keep only the `top_k` most probable tokens before choosing;\n",
    "            `None` disables the filter. Values larger than the vocabulary\n",
    "            are clamped to the vocabulary size.\n",
    "\n",
    "    Returns:\n",
    "        The decoded text (prompt plus generated tokens) of the first sequence.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    input_ids = tokenizer(prompt, return_tensors='pt').input_ids\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for _ in range(n_tokens_to_gen):\n",
    "            # Only the logits at the final position predict the next token.\n",
    "            next_token_logits = model(input_ids)[:, -1]\n",
    "            probs = F.softmax(next_token_logits, dim=-1)\n",
    "\n",
    "            if top_k is not None:\n",
    "                # torch.topk raises when k exceeds the size of the dimension,\n",
    "                # so never ask for more entries than the vocabulary holds.\n",
    "                k = min(top_k, probs.shape[-1])\n",
    "                kth_largest = torch.topk(probs, k=k).values[:, -1, None]\n",
    "                # Zero everything strictly below the k-th value (ties with it\n",
    "                # are kept), then renormalise so each row sums to 1 again.\n",
    "                probs = probs.masked_fill(probs < kth_largest, 0)\n",
    "                probs = probs / probs.sum(dim=1, keepdim=True)\n",
    "\n",
    "            if sample:\n",
    "                next_indices = torch.multinomial(probs, num_samples=1)\n",
    "            else:\n",
    "                next_indices = torch.argmax(probs, dim=-1)[:, None]\n",
    "\n",
    "            input_ids = torch.cat([input_ids, next_indices], dim=1)\n",
    "\n",
    "    # Only the first sequence is returned, so only that one is decoded.\n",
    "    return tokenizer.decode(input_ids[0].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ee877143-2042-4579-8042-a96db6200517",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)\n"
     ]
    }
   ],
   "source": [
    "# Continue a short open-ended prompt using generate()'s default settings.\n",
    "print(generate(model, tokenizer, prompt='Mamba is the'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "65d70549-597f-49ca-9185-2184d2576f7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "John: Hi!\n",
      "Sally: Hey!\n",
      "John: So, when's the wedding?\n",
      "Sally: We haven't decided.\n",
      "John: It's in September.\n",
      "Sally: Yeah, we were thinking July or\n",
      "August.\n",
      "John: I'm not too\n"
     ]
    }
   ],
   "source": [
    "# Continue a two-speaker dialogue; the prompt ends right after a speaker label.\n",
    "print(generate(model, tokenizer, prompt='John: Hi!\\nSally:'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6d419fc9-066b-4818-812c-2f1952528bc6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The meaning of life is \n",
      "just this: It is the best you can do.\n",
      "\n",
      "--K.J.\n",
      "\n",
      "And finally: How to handle your emotions. \n",
      "\n",
      "<|endoftext|>Q:\n",
      "\n",
      "Error creating an EntityManager instance in JavaEE 7\n",
      "\n",
      "This is\n"
     ]
    }
   ],
   "source": [
    "# Continue a prompt that ends with a trailing space.\n",
    "print(generate(model, tokenizer, prompt='The meaning of life is '))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2b189e6e-6a96-4770-88cf-7c5de22cb321",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def reverse_string(text, result):\n",
      "        # find the position of the start of the string.\n",
      "        start = text.index(text[0:-1])\n",
      "        # find the position where the string begins changing.\n",
      "        end = text.index\n"
     ]
    }
   ],
   "source": [
    "# Continue the opening of a Python function definition.\n",
    "print(generate(model, tokenizer, prompt='def reverse_string('))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be3afb51-5093-4c64-ac3f-43c2e6b20b10",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6531acc0-b18f-472a-8e99-cee64dd51cd8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0efe197-891a-4ab8-8cea-413d1fb1acda",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e99509b-df7b-4bac-b6a2-669f601ec1c8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
